
<!DOCTYPE html>

<html lang="zh">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

    <title>4.1 ResNet &#8212; 深入浅出PyTorch</title>
    
  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">

    
  <link rel="stylesheet"
    href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">

    <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
    <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
    <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
    <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
    
  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">

    <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    <script src="../_static/jquery.js"></script>
    <script src="../_static/underscore.js"></script>
    <script src="../_static/doctools.js"></script>
    <script>let toggleHintShow = 'Click to show';</script>
    <script>let toggleHintHide = 'Click to hide';</script>
    <script>let toggleOpenOnPrint = 'true';</script>
    <script src="../_static/togglebutton.js"></script>
    <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
    <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
    <link rel="index" title="索引" href="../genindex.html" />
    <link rel="search" title="搜索" href="../search.html" />
    <link rel="next" title="&lt;no title&gt;" href="4.3%20DenseNet.html" />
    <link rel="prev" title="第四章：PyTorch基础实战" href="index.html" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="docsearch:language" content="zh">
    

    <!-- Google Analytics -->
    
  </head>
  <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
<!-- Checkboxes to toggle the left sidebar -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
<label class="overlay overlay-navbar" for="__navigation">
    <div class="visually-hidden">Toggle navigation sidebar</div>
</label>
<!-- Checkboxes to toggle the in-page toc -->
<input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
<label class="overlay overlay-pagetoc" for="__page-toc">
    <div class="visually-hidden">Toggle in-page Table of Contents</div>
</label>
<!-- Headers at the top -->
<div class="announcement header-item noprint"></div>
<div class="header header-item noprint"></div>

    
    <div class="container-fluid" id="banner"></div>

    

    <div class="container-xl">
      <div class="row">
          
<!-- Sidebar -->
<div class="bd-sidebar noprint" id="site-navigation">
    <div class="bd-sidebar__content">
        <div class="bd-sidebar__top"><div class="navbar-brand-box">
    <a class="navbar-brand text-wrap" href="../index.html">
      
      
      
      <h1 class="site-logo" id="site-title">深入浅出PyTorch</h1>
      
    </a>
</div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  <i class="icon fas fa-search"></i>
  <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
    <div class="bd-toc-item active">
        <p aria-level="2" class="caption" role="heading">
 <span class="caption-text">
  目录
 </span>
</p>
<ul class="current nav bd-sidenav">
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/index.html">
   第零章：前置知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  <label for="toctree-checkbox-1">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.1%20%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E7%AE%80%E5%8F%B2.html">
     人工智能简史
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.2%20%E8%AF%84%E4%BB%B7%E6%8C%87%E6%A0%87.html">
     模型评价指标
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.3%20%E5%B8%B8%E7%94%A8%E5%8C%85%E7%9A%84%E5%AD%A6%E4%B9%A0.html">
     常用包的学习
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.4%20Jupyter%E7%9B%B8%E5%85%B3%E6%93%8D%E4%BD%9C.html">
     Jupyter notebook/Lab 简述
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/index.html">
   第一章：PyTorch的简介和安装
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  <label for="toctree-checkbox-2">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.1%20PyTorch%E7%AE%80%E4%BB%8B.html">
     1.1 PyTorch简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.2%20PyTorch%E7%9A%84%E5%AE%89%E8%A3%85.html">
     1.2 PyTorch的安装
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.3%20PyTorch%E7%9B%B8%E5%85%B3%E8%B5%84%E6%BA%90.html">
     1.3 PyTorch相关资源
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/index.html">
   第二章：PyTorch基础知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  <label for="toctree-checkbox-3">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.1%20%E5%BC%A0%E9%87%8F.html">
     2.1 张量
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.2%20%E8%87%AA%E5%8A%A8%E6%B1%82%E5%AF%BC.html">
     2.2 自动求导
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.3%20%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97%E7%AE%80%E4%BB%8B.html">
     2.3 并行计算简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.4%20AI%E7%A1%AC%E4%BB%B6%E5%8A%A0%E9%80%9F%E8%AE%BE%E5%A4%87.html">
     AI硬件加速设备
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/index.html">
   第三章：PyTorch的主要组成模块
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  <label for="toctree-checkbox-4">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.1%20%E6%80%9D%E8%80%83%EF%BC%9A%E5%AE%8C%E6%88%90%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%BF%85%E8%A6%81%E9%83%A8%E5%88%86.html">
     3.1 思考：完成深度学习的必要部分
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.2%20%E5%9F%BA%E6%9C%AC%E9%85%8D%E7%BD%AE.html">
     3.2 基本配置
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html">
     3.3 数据读入
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.html">
     3.4 模型构建
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html">
     3.5 模型初始化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.6%20%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     3.6 损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.7%20%E8%AE%AD%E7%BB%83%E4%B8%8E%E8%AF%84%E4%BC%B0.html">
     3.7 训练和评估
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     3.8 可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.9%20%E4%BC%98%E5%8C%96%E5%99%A8.html">
     3.9 PyTorch优化器
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 current active has-children">
  <a class="reference internal" href="index.html">
   第四章：PyTorch基础实战
  </a>
  <input checked="" class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  <label for="toctree-checkbox-5">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul class="current">
   <li class="toctree-l2 current active">
    <a class="current reference internal" href="#">
     4.1 ResNet
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="4.4%20FashionMNIST%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     基础实战——FashionMNIST时装分类
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/index.html">
   第五章：PyTorch模型定义
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
  <label for="toctree-checkbox-6">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.1%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E7%9A%84%E6%96%B9%E5%BC%8F.html">
     5.1 PyTorch模型定义的方式
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.2%20%E5%88%A9%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%9D%97%E5%BF%AB%E9%80%9F%E6%90%AD%E5%BB%BA%E5%A4%8D%E6%9D%82%E7%BD%91%E7%BB%9C.html">
     5.2 利用模型块快速搭建复杂网络
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html">
     5.3 PyTorch修改模型
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.4%20PyTorh%E6%A8%A1%E5%9E%8B%E4%BF%9D%E5%AD%98%E4%B8%8E%E8%AF%BB%E5%8F%96.html">
     5.4 PyTorch模型保存与读取
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html">
   第六章：PyTorch进阶训练技巧
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/>
  <label for="toctree-checkbox-7">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.1%20%E8%87%AA%E5%AE%9A%E4%B9%89%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     6.1 自定义损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.2%20%E5%8A%A8%E6%80%81%E8%B0%83%E6%95%B4%E5%AD%A6%E4%B9%A0%E7%8E%87.html">
     6.2 动态调整学习率
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-torchvision.html">
     6.3 模型微调-torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-timm.html">
     6.3 模型微调 - timm
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.4%20%E5%8D%8A%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83.html">
     6.4 半精度训练
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html">
     6.5 数据增强-imgaug
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.6%20%E4%BD%BF%E7%94%A8argparse%E8%BF%9B%E8%A1%8C%E8%B0%83%E5%8F%82.html">
     6.6 使用argparse进行调参
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/index.html">
   第七章：PyTorch可视化
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/>
  <label for="toctree-checkbox-8">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.1%20%E5%8F%AF%E8%A7%86%E5%8C%96%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84.html">
     7.1 可视化网络结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.2%20CNN%E5%8D%B7%E7%A7%AF%E5%B1%82%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     7.2 CNN可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.3%20%E4%BD%BF%E7%94%A8TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.3 使用TensorBoard可视化训练过程
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.4%20%E4%BD%BF%E7%94%A8wandb%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.4 使用wandb可视化训练过程
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/index.html">
   第八章：PyTorch生态简介
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/>
  <label for="toctree-checkbox-9">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.1%20%E6%9C%AC%E7%AB%A0%E7%AE%80%E4%BB%8B.html">
     8.1 本章简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.2%20%E5%9B%BE%E5%83%8F%20-%20torchvision.html">
     8.2 torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.3%20%E8%A7%86%E9%A2%91%20-%20PyTorchVideo.html">
     8.3 PyTorchVideo简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.4%20%E6%96%87%E6%9C%AC%20-%20torchtext.html">
     8.4 torchtext简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.5%20%E9%9F%B3%E9%A2%91%20-%20torchaudio.html">
     8.5 torchaudio简介
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/index.html">
   第九章：PyTorch的模型部署
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/>
  <label for="toctree-checkbox-10">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/9.1%20%E4%BD%BF%E7%94%A8ONNX%E8%BF%9B%E8%A1%8C%E9%83%A8%E7%BD%B2%E5%B9%B6%E6%8E%A8%E7%90%86.html">
     9.1 使用ONNX进行部署并推理
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/index.html">
   第十章：常见代码解读
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/>
  <label for="toctree-checkbox-11">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/10.1%20%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     10.1 图像分类简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/10.2%20%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B.html">
     目标检测简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/10.3%20%E5%9B%BE%E5%83%8F%E5%88%86%E5%89%B2.html">
     10.3 图像分割简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/ResNet%E6%BA%90%E7%A0%81%E8%A7%A3%E8%AF%BB.html">
     ResNet源码解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/RNN%E8%AF%A6%E8%A7%A3%E5%8F%8A%E5%85%B6%E5%AE%9E%E7%8E%B0.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/Transformer%20%E8%A7%A3%E8%AF%BB.html">
     Transformer 解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/ViT%E8%A7%A3%E8%AF%BB.html">
     ViT解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%8D%81%E7%AB%A0/Swin-Transformer%E8%A7%A3%E8%AF%BB.html">
     Swin Transformer解读
    </a>
   </li>
  </ul>
 </li>
</ul>

    </div>
</nav></div>
        <div class="bd-sidebar__bottom">
             <!-- To handle the deprecated key -->
            
            <div class="navbar_extra_footer">
            Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
            </div>
            
        </div>
    </div>
    <div id="rtd-footer-container"></div>
</div>


          


          
<!-- A tiny helper pixel to detect if we've scrolled -->
<div class="sbt-scroll-pixel-helper"></div>
<!-- Main content -->
<div class="col py-0 content-container">
    
    <div class="header-article row sticky-top noprint">
        



<div class="col py-1 d-flex header-article-main">
    <div class="header-article__left">
        
        <label for="__navigation"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="right"
title="Toggle navigation"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-bars"></i>
  </span>

</label>

        
    </div>
    <div class="header-article__right">
<button onclick="toggleFullScreen()"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="bottom"
title="Fullscreen mode"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-expand"></i>
  </span>

</button>

<div class="menu-dropdown menu-dropdown-repository-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Source repositories">
      <i class="fab fa-github"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Source repository"
>
  

<span class="headerbtn__icon-container">
  <i class="fab fa-github"></i>
  </span>
<span class="headerbtn__text-container">repository</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/issues/new?title=Issue%20on%20page%20%2F第四章/4.1 ResNet.html&body=Your%20issue%20content%20here."
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Open an issue"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-lightbulb"></i>
  </span>
<span class="headerbtn__text-container">open issue</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/edit/master/第四章/4.1 ResNet.md"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Edit this page"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-pencil-alt"></i>
  </span>
<span class="headerbtn__text-container">suggest edit</span>
</a>

      </li>
      
    </ul>
  </div>
</div>

<div class="menu-dropdown menu-dropdown-download-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Download this page">
      <i class="fas fa-download"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="../_sources/第四章/4.1 ResNet.md.txt"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Download source file"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file"></i>
  </span>
<span class="headerbtn__text-container">.md</span>
</a>

      </li>
      
      <li>
        
<button onclick="printPdf(this)"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="left"
title="Print to PDF"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file-pdf"></i>
  </span>
<span class="headerbtn__text-container">.pdf</span>
</button>

      </li>
      
    </ul>
  </div>
</div>
<label for="__page-toc"
  class="headerbtn headerbtn-page-toc"
  
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-list"></i>
  </span>

</label>

    </div>
</div>

<!-- Table of contents -->
<div class="col-md-3 bd-toc show noprint">
    <div class="tocsection onthispage pt-5 pb-3">
        <i class="fas fa-list"></i> Contents
    </div>
    <nav id="bd-toc-nav" aria-label="Page">
        <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   1 Introduction
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   2 Source code walkthrough
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     2.1 Convolution wrappers
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id4">
     2.2 Design of the basic blocks
    </a>
    <ul class="nav section-nav flex-column">
     <li class="toc-h4 nav-item toc-entry">
      <a class="reference internal nav-link" href="#shortcut-connection">
       2.2.1 Shortcut Connection.
      </a>
     </li>
     <li class="toc-h4 nav-item toc-entry">
      <a class="reference internal nav-link" href="#basicblock">
       2.2.2 BasicBlock
      </a>
     </li>
     <li class="toc-h4 nav-item toc-entry">
      <a class="reference internal nav-link" href="#bottleneck">
       2.2.3 BottleNeck
      </a>
     </li>
    </ul>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id5">
     2.3 Overall network structure
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id6">
   3 Summary
  </a>
 </li>
</ul>

    </nav>
</div>
    </div>
    <div class="article row">
        <div class="col pl-md-3 pl-lg-5 content-container">
            <!-- Table of contents that is only displayed when printing the page -->
            <div id="jb-print-docs-body" class="onlyprint">
                <h1>4.1 ResNet</h1>
                <!-- Table of contents -->
                <div id="print-main-content">
                    <div id="jb-print-toc">
                        
                        <div>
                            <h2> Contents </h2>
                        </div>
                        <nav aria-label="Page">
                            <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   1 Introduction
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   2 Source code walkthrough
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     2.1 Convolution wrappers
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id4">
     2.2 Design of the basic blocks
    </a>
    <ul class="nav section-nav flex-column">
     <li class="toc-h4 nav-item toc-entry">
      <a class="reference internal nav-link" href="#shortcut-connection">
       2.2.1 Shortcut Connection.
      </a>
     </li>
     <li class="toc-h4 nav-item toc-entry">
      <a class="reference internal nav-link" href="#basicblock">
       2.2.2 BasicBlock
      </a>
     </li>
     <li class="toc-h4 nav-item toc-entry">
      <a class="reference internal nav-link" href="#bottleneck">
       2.2.3 BottleNeck
      </a>
     </li>
    </ul>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id5">
     2.3 Overall network structure
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id6">
   3 Summary
  </a>
 </li>
</ul>

                        </nav>
                    </div>
                </div>
            </div>
            <main id="main-content" role="main">
                
              <div>
                
  <section class="tex2jax_ignore mathjax_ignore" id="resnet">
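<h1>4.1 ResNet<a class="headerlink" href="#resnet" title="Permalink to this heading">#</a></h1>
<p>The residual neural network (ResNet) was proposed by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun at Microsoft Research. Its main contribution was the observation that, as a network gets deeper, training accuracy first saturates and then drops when still more layers are added, and that this drop is not caused by overfitting. The authors called this the "degradation" problem and introduced the "shortcut connection" to address it, which greatly eased the difficulty of training very deep networks. Network depth exceeded 100 layers for the first time, and the largest networks went beyond 1000 layers. (We pay our deepest respects here to the late Dr. Jian Sun.)</p>
<p>After reading this section you will understand:</p>
<ul class="simple">
<li><p>What vanishing and exploding gradients are</p></li>
<li><p>Why the code defines two structures, BasicBlock and Bottleneck</p></li>
<li><p>What expansion does in the code</p></li>
</ul>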
<section id="id1">
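<h2>1 Introduction<a class="headerlink" href="#id1" title="Permalink to this heading">#</a></h2>
<p>With the rise of convolutional neural networks, people found that multi-layer convolutional or fully connected networks outperform single-layer ones, and many came to assume that more layers always give better results. However, Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun at Microsoft Research found that making a network deeper could actually make it perform considerably worse. They argued that very deep networks are hard to train because information is lost or degraded to some degree as it passes through the layers, and because gradients may vanish or explode. They proposed ResNet to address this, and its arrival moved neural networks toward truly deep architectures. ResNet's key contribution is the shortcut connection, which feeds the input directly to later layers; this mitigates vanishing and exploding gradients to some extent and improves the performance of deep networks. Next we explain <strong>vanishing gradients</strong> and <strong>exploding gradients</strong> in more detail.</p>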
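<p>Vanishing and exploding gradients stem mainly from the structure of deep networks and from the backpropagation algorithm. Today's methods for optimizing neural networks are based on backpropagation: the error computed by the loss function is propagated backward to guide the update of the network weights. The error gradient gives the direction and magnitude with which the weights are updated during training. In deep networks or recurrent neural networks, error gradients can accumulate during updates and become very large, causing large weight updates and making the network unstable. In extreme cases the weights become so large that they overflow and turn into NaN values. <strong>Exploding gradients arise from the exponential growth caused by repeatedly multiplying gradients (with values greater than 1.0) across layers.</strong> In a deep multilayer perceptron, exploding gradients make the network unstable: at best it cannot learn from the training data, and at worst the weights become NaN and can no longer be updated.</p>
<p>In other cases the gradient becomes very small: <strong>vanishing gradients arise from the exponential shrinkage caused by repeatedly multiplying gradients (with values less than 1.0) across layers.</strong> In the worst case this can stop further training completely. For example, traditional activation functions such as the hyperbolic tangent have derivatives in the range (0, 1], and backpropagation computes gradients with the chain rule. As a result, the gradient of the early ("front") layers of an n-layer network is a product of n such small numbers, so the gradient (the error signal) decays exponentially with n and the early layers train very slowly, until updates eventually stall.</p>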
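<p>The repeated multiplication described above can be illustrated with a toy calculation. This is only a sketch under a strong simplifying assumption: every layer is taken to contribute the same scalar factor to the gradient, and the function name gradient_scale and the factors 0.5 and 1.5 are made up for illustration.</p>

```python
# Toy model of backpropagation through n layers: the gradient reaching the
# first layer is treated as a product of n identical per-layer factors.
# (Assumption for illustration only; real layers contribute different factors.)
def gradient_scale(per_layer_factor, n_layers):
    scale = 1.0
    for _ in range(n_layers):
        scale *= per_layer_factor
    return scale


for n in (10, 50, 100):
    vanishing = gradient_scale(0.5, n)  # factor below 1.0 shrinks the gradient
    exploding = gradient_scale(1.5, n)  # factor above 1.0 blows it up
    print(f"n={n:3d}  0.5^n={vanishing:.3e}  1.5^n={exploding:.3e}")
```

<p>Even with factors this close to 1, a depth of 100 layers moves the gradient by tens of orders of magnitude in either direction, which is why depth alone can make training unstable or stall it.</p>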
</section>
<section id="id2">
<h2>2 Source code walkthrough<a class="headerlink" href="#id2" title="Permalink to this heading">#</a></h2>
<p>To help you understand ResNet better, we walk through <a class="reference external" href="https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py">the ResNet source code in torchvision</a>.</p>
<section id="id3">
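<h3>2.1 Convolution wrappers<a class="headerlink" href="#id3" title="Permalink to this heading">#</a></h3>
<p>At the top of the file, the 3x3 and 1x1 convolutions are wrapped in helper functions, which makes the rest of the code easier to read. Besides this style, much deep learning code also bundles the convolution layer, the activation layer, and the BN layer into a single helper up front, again for readability.</p>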
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>


<span class="k">def</span> <span class="nf">conv3x3</span><span class="p">(</span><span class="n">in_planes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_planes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">:</span>
    <span class="sd">&quot;&quot;&quot;3x3 convolution with padding&quot;&quot;&quot;</span>
    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span>
        <span class="n">in_planes</span><span class="p">,</span>
        <span class="n">out_planes</span><span class="p">,</span>
        <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
        <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span>
        <span class="n">padding</span><span class="o">=</span><span class="n">dilation</span><span class="p">,</span>
        <span class="n">groups</span><span class="o">=</span><span class="n">groups</span><span class="p">,</span>
        <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="n">dilation</span><span class="o">=</span><span class="n">dilation</span><span class="p">,</span>
    <span class="p">)</span>


<span class="k">def</span> <span class="nf">conv1x1</span><span class="p">(</span><span class="n">in_planes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_planes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">:</span>
    <span class="sd">&quot;&quot;&quot;1x1 convolution&quot;&quot;&quot;</span>
    <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_planes</span><span class="p">,</span> <span class="n">out_planes</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</pre></div>
</div>
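<p>One detail in conv3x3 worth checking by hand is padding=dilation. The sketch below applies the output-size formula given in the nn.Conv2d documentation in plain Python; the helper name conv_out_size and the 56-pixel input are assumptions chosen for illustration, not part of torchvision.</p>

```python
def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # Output-size formula from the nn.Conv2d documentation (floor division).
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# conv3x3 sets padding=dilation, so at stride 1 the spatial size is preserved
# for any dilation.
for dilation in (1, 2, 4):
    assert conv_out_size(56, 3, stride=1, padding=dilation, dilation=dilation) == 56

# With stride 2, both helpers halve a 56-pixel side.
print(conv_out_size(56, 3, stride=2, padding=1))  # conv3x3 with stride 2
print(conv_out_size(56, 1, stride=2))             # conv1x1 with stride 2
```

<p>Because the 3x3 wrapper keeps the spatial size at stride 1, the output of a block can be added to its input without any resizing unless the stride changes.</p>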
</section>
<section id="id4">
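<h3>2.2 Design of the basic blocks<a class="headerlink" href="#id4" title="Permalink to this heading">#</a></h3>
<p>ResNet is built by stacking many identical modules. To keep the code readable and extensible, it uses a modular design: two basic blocks, BasicBlock and BottleNeck, are written to cover ResNets of different sizes. This kind of modular design is common in today's deep learning code.</p>
<p>The common ResNet sizes are ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152, shown in the figure below. The number in each name is the number of layers in the network.</p>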
<p><img alt="expansion" src="../_images/expansion1.jpg" /></p>
<p>To make this concrete, take ResNet-101 as an example.</p>
<table class="colwidths-auto table">
<thead>
<tr class="row-odd"><th class="head"><p>layer_name</p></th>
<th class="head"><p>Weighted layers</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p>conv1</p></td>
<td><p>1 convolution</p></td>
</tr>
<tr class="row-odd"><td><p>conv2_x</p></td>
<td><p>3 blocks x 3 convolutions = 9</p></td>
</tr>
<tr class="row-even"><td><p>conv3_x</p></td>
<td><p>4 blocks x 3 convolutions = 12</p></td>
</tr>
<tr class="row-odd"><td><p>conv4_x</p></td>
<td><p>23 blocks x 3 convolutions = 69</p></td>
</tr>
<tr class="row-even"><td><p>conv5_x</p></td>
<td><p>3 blocks x 3 convolutions = 9</p></td>
</tr>
<tr class="row-odd"><td><p>fc</p></td>
<td><p>1 fully connected layer (the average pooling before it has no weights)</p></td>
</tr>
<tr class="row-even"><td><p>Total</p></td>
<td><p>1 + 9 + 12 + 69 + 9 + 1 = 101</p></td>
</tr>
</tbody>
</table>
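<p>The same count can be reproduced for the other variants with a few lines of arithmetic. The sketch below is only illustrative: the helper <code>count_weighted_layers</code> and the <code>CONFIGS</code> dictionary are names introduced here, not part of torchvision. It assumes the standard per-stage block counts and, as in the table, ignores the 1 x 1 convolutions on the shortcut branches.</p>

```python
# Blocks per stage (conv2_x .. conv5_x) and convolutions per block.
# BasicBlock contains 2 convolutions, Bottleneck contains 3.
CONFIGS = {
    "resnet18": ([2, 2, 2, 2], 2),
    "resnet34": ([3, 4, 6, 3], 2),
    "resnet50": ([3, 4, 6, 3], 3),
    "resnet101": ([3, 4, 23, 3], 3),
    "resnet152": ([3, 8, 36, 3], 3),
}


def count_weighted_layers(blocks_per_stage, convs_per_block):
    """Count conv1 + all convolutions on the main path + the final fc layer."""
    stem = 1  # conv1
    fc = 1    # final fully connected layer
    return stem + sum(blocks_per_stage) * convs_per_block + fc


for name, (blocks, convs) in CONFIGS.items():
    print(name, count_weighted_layers(blocks, convs))
```

<p>Each printed number matches the suffix in the network's name, which is where names such as ResNet-101 come from.</p>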
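<p>Looking at the blocks of the different ResNets above, we can see that in ResNet-18 and ResNet-34 the number of channels stays the same inside each block, whereas in ResNet-50, ResNet-101 and ResNet-152 it does not: the last convolution of each block outputs 4 times as many channels as the first two, and the kernel sizes follow a 1, 3, 1 pattern. Based on this observation, the blocks of ResNet-18 and ResNet-34 can be treated as one kind, and the blocks of ResNet-50, ResNet-101 and ResNet-152 as another. The torchvision source follows the same idea and uses the BasicBlock (left) and Bottleneck (right) shown in the figure below. To control how the channel count changes, the output width is governed by expansion: both block classes take a feature map with inplanes channels and output a feature map with planes*block.expansion channels, where planes is the base width of the stage (64, 128, 256 or 512) and does not have to equal inplanes. In addition, the curve on the right side of each block in the figure is the shortcut branch, the most important idea of the paper. The downsample operation on that branch adjusts the spatial size or channel count of the shortcut so that the two branches can be added.</p>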
<p><img alt="block" src="../_images/block1.jpg" /></p>
<section id="shortcut-connection">
<h4>2.2.1 Shortcut Connection<a class="headerlink" href="#shortcut-connection" title="Permalink to this heading">#</a></h4>
<p>Let us take a closer look at the shortcut connection:</p>
<p><img alt="shortcut" src="../_images/shortcut1.jpg" /></p>
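<p>A shortcut connection is, literally, a "short cut". It comes in two forms. The first is an identity mapping for matching dimensions, where the input is added directly to the output (the F(x) + x in the figure above). The second handles mismatched dimensions, in which case x needs an extra linear projection so that its dimensions match those of F(x).</p>
<p>Consider the following figure:</p>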
<img src="../../source/第四章/figures/shortcut2.jpg" >
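<p>Left: the VGG-19 model, for reference. Middle: a plain network with 34 parameter layers. Right: a residual network with 34 parameter layers (i.e. resnet34).</p>
<p>In the rightmost path of the figure, the shortcut connections added to the network are clearly visible. The solid lines are plain F(x) + x additions. At the dashed lines, the first convolution of the block has a stride of 2 (that is what the "/2" means), and the depth changes as well: the number of channels doubles. F(x) therefore has half the spatial resolution of x and twice as many channels. For such a shortcut connection a linear projection is needed to match the dimensions, and in ResNet the authors use a 1 x 1 convolution for this purpose.</p>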
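<p>The shape bookkeeping at a dashed shortcut can be checked with the standard convolution output-size formula. The sketch below is a minimal illustration under stated assumptions: <code>conv_out_size</code> is a helper introduced here (it assumes dilation 1), and the 56 x 56 input with 64 channels is just an example stage boundary.</p>

```python
def conv_out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


in_size, in_channels = 56, 64
out_channels = 2 * in_channels  # the channel count doubles at a dashed shortcut

# Main path: the first 3x3 convolution of the block uses stride 2, padding 1.
main_size = conv_out_size(in_size, kernel=3, stride=2, padding=1)
# Shortcut path: a 1x1 convolution with stride 2 and no padding.
shortcut_size = conv_out_size(in_size, kernel=1, stride=2, padding=0)

print("F(x):    ", (out_channels, main_size, main_size))
print("shortcut:", (out_channels, shortcut_size, shortcut_size))
```

<p>Both branches end up with the same shape, which is the condition for the element-wise addition to be valid.</p>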
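<p>The paper also states: "...where both designs have similar time complexity." If BasicBlock and Bottleneck have similar time complexity, why design a separate Bottleneck structure at all?</p>
<p>As described above, BasicBlock adds a shortcut branch to the conventional convolutional structure to pass along lower-level information, which makes it possible to train very deep networks. Bottleneck first applies a 1x1 convolution that reduces the number of channels, so that the middle convolution works on 1/4 as many channels; the middle 3x3 convolution keeps its output channel count equal to its input channel count; and the third convolution, again 1x1, restores the channel count, so that the output of the Bottleneck has as many channels as its input (as in the paper's 256-channel example). <strong>In other words, the two 1x1 convolutions reduce the parameter count and computation of the block and shrink its intermediate feature maps, so a single block uses less GPU memory. In deeper networks Bottleneck is more economical in parameters, which helps build networks with more layers while still improving accuracy.</strong> This is why resnet50, resnet101 and resnet152 use the separately designed Bottleneck structure.</p>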
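<p>Both points can be checked by counting convolution weights by hand. The sketch below is illustrative only: <code>conv_params</code> is a helper introduced here, biases and BatchNorm parameters are ignored, and the 64-channel and 256-channel widths are taken from the paper's comparison of the two designs.</p>

```python
def conv_params(in_ch, out_ch, kernel):
    """Number of weights in a bias-free 2D convolution."""
    return in_ch * out_ch * kernel * kernel


# A basic block on 64 channels: two 3x3 convolutions.
basic_64 = 2 * conv_params(64, 64, 3)

# A bottleneck block on 256 channels: 1x1 reduce, 3x3, 1x1 restore.
bottleneck_256 = (
    conv_params(256, 64, 1)
    + conv_params(64, 64, 3)
    + conv_params(64, 256, 1)
)

# A basic block applied directly to 256 channels, for comparison.
basic_256 = 2 * conv_params(256, 256, 3)

print("basic block, 64 channels:  ", basic_64)
print("bottleneck, 256 channels:  ", bottleneck_256)
print("basic block, 256 channels: ", basic_256)
```

<p>The first two numbers are close, which matches the "similar time complexity" remark, while a basic block working directly on 256 channels would need roughly 17 times as many weights as the bottleneck.</p>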
</section>
<section id="basicblock">
<h4>2.2.2 BasicBlock<a class="headerlink" href="#basicblock" title="Permalink to this heading">#</a></h4>
<p>The BasicBlock module is used to build resnet18 and resnet34.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">BasicBlock</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="n">expansion</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">inplanes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">planes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">downsample</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">base_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
        <span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">norm_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="k">if</span> <span class="n">norm_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">norm_layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span>
        <span class="k">if</span> <span class="n">groups</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">base_width</span> <span class="o">!=</span> <span class="mi">64</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;BasicBlock only supports groups=1 and base_width=64&quot;</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">dilation</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;Dilation &gt; 1 not supported in BasicBlock&quot;</span><span class="p">)</span>
        <span class="c1"># Both self.conv1 and self.downsample layers downsample the input when stride != 1</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">conv3x3</span><span class="p">(</span><span class="n">inplanes</span><span class="p">,</span> <span class="n">planes</span><span class="p">,</span> <span class="n">stride</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">planes</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">conv3x3</span><span class="p">(</span><span class="n">planes</span><span class="p">,</span> <span class="n">planes</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">planes</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span> <span class="o">=</span> <span class="n">downsample</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="n">identity</span> <span class="o">=</span> <span class="n">x</span>  <span class="c1"># keep a copy of the input for the shortcut</span>

        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># first 3x3 convolution (applies the stride)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>  <span class="c1"># batch normalization</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>  <span class="c1"># ReLU activation</span>

        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>  <span class="c1"># second 3x3 convolution</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>  <span class="c1"># batch normalization</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">identity</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># match the shortcut to the shape of out</span>

        <span class="n">out</span> <span class="o">+=</span> <span class="n">identity</span>  <span class="c1"># residual addition: F(x) + shortcut</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">out</span>
</pre></div>
</div>
</section>
<section id="bottleneck">
<h4>2.2.3 Bottleneck<a class="headerlink" href="#bottleneck" title="Permalink to this heading">#</a></h4>
<p>The Bottleneck module is used to build resnet50, resnet101 and resnet152.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Bottleneck</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="c1"># Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)</span>
    <span class="c1"># while original implementation places the stride at the first 1x1 convolution(self.conv1)</span>
    <span class="c1"># according to &quot;Deep residual learning for image recognition&quot;https://arxiv.org/abs/1512.03385.</span>
    <span class="c1"># This variant is also known as ResNet V1.5 and improves accuracy according to</span>
    <span class="c1"># https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.</span>

    <span class="n">expansion</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span>  <span class="c1"># output channels = planes * expansion</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">inplanes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">planes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">downsample</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">base_width</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
        <span class="n">dilation</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">norm_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="k">if</span> <span class="n">norm_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">norm_layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span>
        <span class="n">width</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">planes</span> <span class="o">*</span> <span class="p">(</span><span class="n">base_width</span> <span class="o">/</span> <span class="mf">64.0</span><span class="p">))</span> <span class="o">*</span> <span class="n">groups</span>
        <span class="c1"># Both self.conv2 and self.downsample layers downsample the input when stride != 1</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">conv1x1</span><span class="p">(</span><span class="n">inplanes</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">width</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">conv3x3</span><span class="p">(</span><span class="n">width</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">groups</span><span class="p">,</span> <span class="n">dilation</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">width</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span> <span class="o">=</span> <span class="n">conv1x1</span><span class="p">(</span><span class="n">width</span><span class="p">,</span> <span class="n">planes</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">expansion</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">planes</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">expansion</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span> <span class="o">=</span> <span class="n">downsample</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span>

    <span class="c1"># forward follows the same pattern as BasicBlock.forward, so it is not annotated again</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="n">identity</span> <span class="o">=</span> <span class="n">x</span>

        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>

        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn2</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>

        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv3</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn3</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">identity</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

        <span class="n">out</span> <span class="o">+=</span> <span class="n">identity</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">out</span>
</pre></div>
</div>
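<p>The <code>width</code> computed in <code>__init__</code> is what lets the same Bottleneck class serve other variants. The sketch below simply re-evaluates that one expression outside the class. The function name <code>bottleneck_width</code> is introduced here for illustration, and the (groups, base_width) pairs are the settings torchvision uses for ResNeXt-50 32x4d and Wide ResNet-50-2, which are not otherwise covered in this section.</p>

```python
def bottleneck_width(planes, base_width=64, groups=1):
    """Same expression as the `width` line in Bottleneck.__init__."""
    return int(planes * (base_width / 64.0)) * groups


# First stage (planes=64) under three different settings.
print("resnet50:        ", bottleneck_width(64))
print("resnext50_32x4d: ", bottleneck_width(64, base_width=4, groups=32))
print("wide_resnet50_2: ", bottleneck_width(64, base_width=128))
```

<p>With the defaults, <code>width</code> equals <code>planes</code>, so for the plain ResNets discussed here the inner convolutions simply use <code>planes</code> channels.</p>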
<p>Here we explain the role of <strong>expansion</strong> in the code. Take another look at the figure below.</p>
<p><img alt="" src="../_images/expansion1.jpg" /></p>
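<p>Put simply, expansion is a multiplier applied to the output channels. The __init__ functions of BasicBlock and Bottleneck have three key parameters, inplanes, planes and stride: the number of input channels, the base number of channels of the block, and the stride. For both blocks the planes passed to __init__ are 64, 128, 256 and 512, but the table in the figure above shows that ResNet-50, ResNet-101 and ResNet-152 need 256, 512, 1024 and 2048 output channels. This is why Bottleneck sets expansion = 4: in the Bottleneck code above, self.conv3 and self.bn3 multiply planes by expansion to obtain the required channel count. For ResNet-18 and ResNet-34 the output channel count of a block is simply planes, so BasicBlock sets expansion to 1.</p>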
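<p>The effect of expansion on the channel counts can be traced stage by stage. The sketch below is a simplified illustration: <code>stage_channels</code> is a helper introduced here, and it only tracks the first block of each stage, assuming the input to the first stage has 64 channels and that each stage's output becomes the next stage's inplanes.</p>

```python
def stage_channels(expansion, planes_per_stage=(64, 128, 256, 512)):
    """Return (inplanes, planes, out_channels) for the first block of each stage."""
    inplanes = 64  # channels after the stem (conv1 + max pooling)
    trace = []
    for planes in planes_per_stage:
        out_channels = planes * expansion
        trace.append((inplanes, planes, out_channels))
        inplanes = out_channels  # the next stage consumes this stage's output
    return trace


print("BasicBlock (expansion=1):")
for row in stage_channels(expansion=1):
    print("  inplanes=%d planes=%d out=%d" % row)

print("Bottleneck (expansion=4):")
for row in stage_channels(expansion=4):
    print("  inplanes=%d planes=%d out=%d" % row)
```

<p>The last stage yields 512 * expansion channels, which is the input size a final fully connected layer would need.</p>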
</section>
</section>
<section id="id5">
<h3>2.3 Overall network structure<a class="headerlink" href="#id5" title="Permalink to this heading">#</a></h3>
<p>With the basic Bottleneck and BasicBlock defined, we can now build the ResNet network.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">ResNet</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">block</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">BasicBlock</span><span class="p">,</span> <span class="n">Bottleneck</span><span class="p">]],</span> <span class="c1"># which basic block to use</span>
        <span class="n">layers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="c1"># number of blocks in each stage, e.g. [3, 4, 6, 3]</span>
        <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1000</span><span class="p">,</span> <span class="c1"># number of output classes</span>
        <span class="n">zero_init_residual</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># zero-initialize the last BN of each residual branch</span>
        
        <span class="c1">####### options for other variants, not needed for the plain ResNet ######</span>
        <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
        <span class="n">width_per_group</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span>
        <span class="n">replace_stride_with_dilation</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">bool</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
        <span class="c1">#########################################</span>
        
        <span class="n">norm_layer</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># normalization layer</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="n">_log_api_usage_once</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">norm_layer</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">norm_layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_norm_layer</span> <span class="o">=</span> <span class="n">norm_layer</span>
		
        <span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span> <span class="o">=</span> <span class="mi">64</span> <span class="c1"># input channels of the next block; 64 after the stem</span>
        
        <span class="c1">####### options for other variants, not needed for the plain ResNet ######</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># dilation rate (for dilated convolution)</span>
        <span class="k">if</span> <span class="n">replace_stride_with_dilation</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="c1"># each element in the tuple indicates if we should replace</span>
            <span class="c1"># the 2x2 stride with a dilated convolution instead</span>
            <span class="n">replace_stride_with_dilation</span> <span class="o">=</span> <span class="p">[</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">]</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">replace_stride_with_dilation</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
                <span class="s2">&quot;replace_stride_with_dilation should be None &quot;</span>
                <span class="sa">f</span><span class="s2">&quot;or a 3-element tuple, got </span><span class="si">{</span><span class="n">replace_stride_with_dilation</span><span class="si">}</span><span class="s2">&quot;</span>
            <span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">groups</span> <span class="o">=</span> <span class="n">groups</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">base_width</span> <span class="o">=</span> <span class="n">width_per_group</span>
        <span class="c1">#########################################</span>
        
        <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">maxpool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MaxPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># _make_layer builds one stage at a time, giving the network its stage-wise structure</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_layer</span><span class="p">(</span><span class="n">block</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>  <span class="c1"># corresponds to conv2_x</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_layer</span><span class="p">(</span><span class="n">block</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">layers</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dilate</span><span class="o">=</span><span class="n">replace_stride_with_dilation</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>  <span class="c1"># corresponds to conv3_x</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layer3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_layer</span><span class="p">(</span><span class="n">block</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="n">layers</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dilate</span><span class="o">=</span><span class="n">replace_stride_with_dilation</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>  <span class="c1"># corresponds to conv4_x</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layer4</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_make_layer</span><span class="p">(</span><span class="n">block</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="n">layers</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dilate</span><span class="o">=</span><span class="n">replace_stride_with_dilation</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>  <span class="c1"># corresponds to conv5_x</span>
        <span class="c1"># classification head</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">avgpool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">AdaptiveAvgPool2d</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">512</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
		
        <span class="c1"># Weight initialization</span>
        <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">modules</span><span class="p">():</span>
            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">):</span>
                <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">kaiming_normal_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;fan_out&quot;</span><span class="p">,</span> <span class="n">nonlinearity</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)</span>
            <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">)):</span>
                <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">constant_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
                <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">constant_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>

        <span class="c1"># Zero-initialize the last BN in each residual branch,</span>
        <span class="c1"># so that the residual branch starts with zeros, and each residual block behaves like an identity.</span>
        <span class="c1"># This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677</span>
        <span class="k">if</span> <span class="n">zero_init_residual</span><span class="p">:</span>
            <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">modules</span><span class="p">():</span>
                <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">Bottleneck</span><span class="p">)</span> <span class="ow">and</span> <span class="n">m</span><span class="o">.</span><span class="n">bn3</span><span class="o">.</span><span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                    <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">constant_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bn3</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>  <span class="c1"># type: ignore[arg-type]</span>
                <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">BasicBlock</span><span class="p">)</span> <span class="ow">and</span> <span class="n">m</span><span class="o">.</span><span class="n">bn2</span><span class="o">.</span><span class="n">weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                    <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">constant_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bn2</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>  <span class="c1"># type: ignore[arg-type]</span>
    <span class="c1"># Build one layer (stage) by stacking blocks</span>
    <span class="k">def</span> <span class="nf">_make_layer</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">block</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">BasicBlock</span><span class="p">,</span> <span class="n">Bottleneck</span><span class="p">]],</span> <span class="c1"># Block type used in this layer</span>
        <span class="n">planes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>  <span class="c1"># Base channel width; the layer outputs planes * block.expansion channels</span>
        <span class="n">blocks</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="c1"># Number of blocks in this layer</span>
        <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="c1"># Stride of the first block</span>
        <span class="n">dilate</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># Dilated convolution; not relevant here</span>
    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">:</span>
        <span class="n">norm_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_norm_layer</span>
        <span class="n">downsample</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># Optional downsample branch for the shortcut</span>
        <span class="c1">#################### Not relevant here #####################</span>
        <span class="n">previous_dilation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> 
        <span class="k">if</span> <span class="n">dilate</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">dilation</span> <span class="o">*=</span> <span class="n">stride</span>
            <span class="n">stride</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="c1">#############################################</span>
        <span class="k">if</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span> <span class="o">!=</span> <span class="n">planes</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span><span class="p">:</span>
            <span class="n">downsample</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
                <span class="n">conv1x1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span><span class="p">,</span> <span class="n">planes</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span><span class="p">,</span> <span class="n">stride</span><span class="p">),</span>
                <span class="n">norm_layer</span><span class="p">(</span><span class="n">planes</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span><span class="p">),</span>
            <span class="p">)</span>
		
        <span class="c1"># Collect the blocks of this layer in the list layers</span>
        <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
            <span class="n">block</span><span class="p">(</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span><span class="p">,</span> <span class="n">planes</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">downsample</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">base_width</span><span class="p">,</span> <span class="n">previous_dilation</span><span class="p">,</span> <span class="n">norm_layer</span>
            <span class="p">)</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span> <span class="o">=</span> <span class="n">planes</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span>
        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">blocks</span><span class="p">):</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
                <span class="n">block</span><span class="p">(</span>
                    <span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span><span class="p">,</span>
                    <span class="n">planes</span><span class="p">,</span>
                    <span class="n">groups</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">groups</span><span class="p">,</span>
                    <span class="n">base_width</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">base_width</span><span class="p">,</span>
                    <span class="n">dilation</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dilation</span><span class="p">,</span>
                    <span class="n">norm_layer</span><span class="o">=</span><span class="n">norm_layer</span><span class="p">,</span>
                <span class="p">)</span>
            <span class="p">)</span>
        <span class="c1"># Wrap the list of blocks in nn.Sequential</span>
        <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_forward_impl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="c1"># See note [TorchScript super()]</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># conv1   x shape [1 64 112 112]</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bn1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>   <span class="c1"># Normalization</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># Activation</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">maxpool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># 3x3 maxpool of conv2_x        x shape [1 64 56 56]</span>

        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># layer 1</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># layer 2</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># layer 3</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer4</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># layer 4</span>

        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">avgpool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Adaptive average pooling</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> 
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Classification</span>

        <span class="k">return</span> <span class="n">x</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_forward_impl</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> 
</pre></div>
</div>
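<p>Reading the code above, the network first applies a 7 x 7 convolution to the 3-channel input image and outputs a 64-channel feature map (64 is the initial value of self.inplanes), followed by a BatchNorm layer, a ReLU layer, and a MaxPool layer. The result then passes through the four layers built by _make_layer(), an adaptive average pooling layer, and finally an fc layer that produces the classification output. Once the network is assembled, its parameters are initialized: Conv2d weights with Kaiming normal initialization, BatchNorm2d and GroupNorm weights to 1 and biases to 0, and, when zero_init_residual is set, the weight of the last BN in each residual branch to 0.</p>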
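<p>The data flow described above can be checked with a little arithmetic. The following is a minimal sketch in plain Python, not part of torchvision: the helper conv_out is a hypothetical name, the 224 x 224 input size is an assumption, and the stem settings (7 x 7 convolution with stride 2 and padding 3, 3 x 3 max pooling with stride 2 and padding 1) are assumed to match the standard ResNet stem.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>def conv_out(size, kernel, stride, padding):
    # Output length of a convolution or pooling window along one spatial axis
    return (size + 2 * padding - kernel) // stride + 1

size = 224                       # assumed input height and width
trace = []

size = conv_out(size, kernel=7, stride=2, padding=3)   # conv1
trace.append(("conv1", 64, size))
size = conv_out(size, kernel=3, stride=2, padding=1)   # maxpool
trace.append(("maxpool", 64, size))

# (name, planes, stride) of the four layers; assumed to match the constructor
stages = [("layer1", 64, 1), ("layer2", 128, 2), ("layer3", 256, 2), ("layer4", 512, 2)]
expansion = 4                    # Bottleneck; use 1 for BasicBlock
for name, planes, stride in stages:
    # Only the first block of a layer changes the spatial size (3x3 conv, padding 1)
    size = conv_out(size, kernel=3, stride=stride, padding=1)
    trace.append((name, planes * expansion, size))

trace.append(("avgpool", 512 * expansion, 1))          # AdaptiveAvgPool2d((1, 1))

for name, channels, side in trace:
    print(name, channels, side)
</pre></div>
</div>
<p>Under these assumptions the spatial side length goes 112, 56, 56, 28, 14, 7, and finally 1 after adaptive pooling, which agrees with the shape comments in _forward_impl.</p>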
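<p>Each call to _make_layer() builds one layer, and each layer is a stack of one of the two basic modules described above. The block argument selects the module type, either BasicBlock or Bottleneck; blocks is the number of modules stacked in the layer; planes determines the layer's output width, which is planes * block.expansion channels. The fourth argument, stride, is the convolution stride of the layer's first block. The function first checks whether stride is not 1 or the channel counts do not match (self.inplanes != planes * block.expansion). In either case it defines a downsample branch, a 1 x 1 convolution followed by a normalization layer, and passes it to the first block, so that the shortcut input x and the block output out have the same shape and can be added. A loop then appends the remaining blocks; because the shapes already match, these blocks need no downsample branch. The condition is shown in the code below:</p>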
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">if</span> <span class="n">stride</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span> <span class="o">!=</span> <span class="n">planes</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span><span class="p">:</span>
    <span class="n">downsample</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
        <span class="n">conv1x1</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">inplanes</span><span class="p">,</span> <span class="n">planes</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span><span class="p">,</span> <span class="n">stride</span><span class="p">),</span>
        <span class="n">norm_layer</span><span class="p">(</span><span class="n">planes</span> <span class="o">*</span> <span class="n">block</span><span class="o">.</span><span class="n">expansion</span><span class="p">),</span>
    <span class="p">)</span>
</pre></div>
</div>
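<p>When a layer contains several blocks, they are stacked by appending each block to the layers list in turn. The first block is handled specially: depending on self.inplanes, planes, and stride, it may include a downsample branch, and its output has planes * block.expansion channels. self.inplanes is then updated to this value and serves as the input channel count of the remaining blocks. Those blocks use the default stride of 1, and their input (self.inplanes) and output (planes * block.expansion) channel counts are equal, so neither the spatial size nor the channel count of the feature map changes and no downsample branch is needed.</p>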
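<p>To make this bookkeeping concrete, here is a small sketch that reproduces only the channel and stride logic of _make_layer in plain Python, without building any modules. The functions plan_layer and plan_network are hypothetical helpers introduced for illustration; the block counts [2, 2, 2, 2] and [3, 4, 6, 3] are the usual ResNet-18 and ResNet-50 configurations, and the 64 channels after the stem are assumed as in the constructor.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>def plan_layer(inplanes, planes, blocks, stride, expansion):
    # Mirrors the channel bookkeeping of _make_layer without building modules
    out_channels = planes * expansion
    needs_downsample = stride != 1 or inplanes != out_channels
    # Each entry: (input channels, output channels, stride, has downsample branch)
    plan = [(inplanes, out_channels, stride, needs_downsample)]   # first block
    for _ in range(1, blocks):
        plan.append((out_channels, out_channels, 1, False))       # remaining blocks
    return plan, out_channels

def plan_network(layers, expansion):
    inplanes = 64                       # channels after the stem (assumed)
    network = []
    for planes, blocks, stride in zip([64, 128, 256, 512], layers, [1, 2, 2, 2]):
        plan, inplanes = plan_layer(inplanes, planes, blocks, stride, expansion)
        network.append(plan)
    return network

resnet18 = plan_network([2, 2, 2, 2], expansion=1)   # BasicBlock
resnet50 = plan_network([3, 4, 6, 3], expansion=4)   # Bottleneck

for name, net in [("resnet18", resnet18), ("resnet50", resnet50)]:
    for i, plan in enumerate(net, start=1):
        print(name, "layer%d" % i, plan[0])
</pre></div>
</div>
<p>The printed first blocks show one difference worth noting: with BasicBlock, layer1 needs no downsample branch because 64 input channels already equal 64 * 1, whereas with Bottleneck, layer1 needs one even at stride 1 because the output has 64 * 4 = 256 channels.</p>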
</section>
</section>
<section id="id6">
<h2>3 Summary<a class="headerlink" href="#id6" title="永久链接至标题">#</a></h2>
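<p>Compared with a plain network, the main advantage of ResNet is the shortcut branch, which connects a layer directly to a later one so that the layers in between only have to learn a residual. Conventional convolutional or fully connected layers tend to lose or degrade some information as it passes through them. ResNet mitigates this to some extent by routing the input directly to the output, which preserves the information; the network then only needs to learn the difference between input and output, which simplifies the learning objective and makes it easier to optimize.</p>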
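<p>As a toy illustration of learning only the difference, the sketch below models a residual block on plain Python lists as y = x + F(x). The functions residual_block, zero_branch, and small_correction are hypothetical names, and the branch functions merely stand in for the stacked convolutions of a real block.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>def residual_block(x, branch):
    # y = x + F(x): the shortcut carries x unchanged, the branch adds a correction
    fx = branch(x)
    return [xi + fi for xi, fi in zip(x, fx)]

def zero_branch(x):
    # A branch that outputs zeros, e.g. right after zero-initializing its last BN
    return [0.0 for _ in x]

def small_correction(x):
    # A branch that has learned a small adjustment of the input
    return [0.1 * xi for xi in x]

x = [1.0, -2.0, 3.0]
identity_out = residual_block(x, zero_branch)
corrected_out = residual_block(x, small_correction)
print(identity_out)
print(corrected_out)
</pre></div>
</div>
<p>With a zero branch the block is exactly the identity, which is the starting point that the zero_init_residual option in the constructor aims for; training then only has to move the branch away from zero where a correction helps.</p>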
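<p>ResNet alleviated, to a degree, the problem that convolutional neural networks get worse as they get deeper. In the authors' words: “Our deep residual nets can easily enjoy accuracy gains from greatly increased depth, producing results substantially better than previous networks.” The original ResNet contributed a great deal to the training of convolutional neural networks, but it also left room for improvement. Later research refined it and produced many improved variants, such as ResNeXt. ResNeXt adopts grouped convolution, a strategy that sits between ordinary convolution and depthwise separable convolution, and balances the two by controlling the number of groups (the cardinality). The idea of grouped convolution comes from Inception, but unlike Inception, where every branch has to be designed by hand, all branches in ResNeXt share the same topology. Combining this with residual connections yields ResNeXt.</p>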
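<p>One way to see how the number of groups moves between the two extremes is to count convolution weights. The sketch below is a back-of-the-envelope calculation in plain Python; conv_weight_count is a hypothetical helper, the 128 channels and 3 x 3 kernel are arbitrary example values, biases are ignored, and the groups == channels case covers only the depthwise part of a depthwise separable convolution.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>def conv_weight_count(in_channels, out_channels, kernel, groups):
    # Each output channel only sees in_channels // groups input channels
    assert in_channels % groups == 0 and out_channels % groups == 0
    return out_channels * (in_channels // groups) * kernel * kernel

channels, kernel = 128, 3
standard = conv_weight_count(channels, channels, kernel, groups=1)
grouped = conv_weight_count(channels, channels, kernel, groups=32)
depthwise = conv_weight_count(channels, channels, kernel, groups=channels)
print(standard, grouped, depthwise)
</pre></div>
</div>
<p>For these example values the counts are 147456, 4608, and 1152, so the weight count shrinks in proportion to the number of groups.</p>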
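<p>ResNet has other variants as well, such as Wider ResNet and DarkNet53. Their changes are comparatively large, especially DarkNet53, which differs considerably from ResNet and only borrows the residual connection to reuse features. In short, ResNet is a milestone in deep learning.</p>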
</section>
</section>


              </div>
              
            </main>
            <footer class="footer-article noprint">
                
    <!-- Previous / next buttons -->
<div class='prev-next-area'>
    <a class='left-prev' id="prev-link" href="index.html" title="上一页 页">
        <i class="fas fa-angle-left"></i>
        <div class="prev-next-info">
            <p class="prev-next-subtitle">Previous</p>
            <p class="prev-next-title">Chapter 4: PyTorch Basics in Practice</p>
        </div>
    </a>
    <a class='right-next' id="next-link" href="4.3%20DenseNet.html" title="下一页 页">
    <div class="prev-next-info">
        <p class="prev-next-subtitle">Next</p>
        <p class="prev-next-title">&lt;no title&gt;</p>
    </div>
    <i class="fas fa-angle-right"></i>
    </a>
</div>
            </footer>
        </div>
    </div>
    <div class="footer-content row">
        <footer class="col footer"><p>
  
    By ZhikangNiu<br/>
  
      &copy; Copyright 2022, ZhikangNiu.<br/>
</p>
        </footer>
    </div>
    
</div>


      </div>
    </div>
  
  <!-- Scripts loaded after <body> so the DOM is not blocked -->
  <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>


  </body>
</html>
